# -*- coding: utf-8 -*-
"""
Created on Sun Sep 21 11:06:59 2025

@author: baran
"""

import numpy as np
import pickle
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import extend
from src.prompts.prompt_maker import input_maker
from src.embedding.embed_tele import get_context
from src.regrets.optimal_rand_seq_tele import opt_eval
from src.regrets.sum_call_seq import get_summary

# ─── STEP 1: Telecom dataset ────────────────────────────────────────────────────
# input_maker returns the question prompts together with their labels and
# explanations (presumably index-aligned: all three are later looked up with
# the same round number — verify against input_maker).
dataset = "telecom"
input_reports, labels, explanations = input_maker("seq", dataset, "")

# ─── STEP 2: Description arrays ─────────────────────────────────────────────────
# Task descriptions for the summary role.
summary_description_array = [
    "Summarize the telecommunications question and its options concisely for analysis.",
    "Provide a brief recap of the telecom question and choices for researchers.",
    "You will take the role of a telecom-specialist summarizer. Summarize the question and answer options.",
    "Produce a short summary of the telecom question and all choices.",
    "Present the telecom question and its multiple-choice options in a concise summary.",
]

# Task descriptions for the diagnosis (answer-selection) role.
diagnosis_description_array = [
    "Answer the telecom MCQ strictly 'option {i}' for this question.",
    "Provide the MCQ answer (1–4) for this telecom question.",
    "Output the telecom MCQ response as 'option {i}'.",
    "Select the correct option (1–4) for the telecommunications question.",
    "Choose the telecom MCQ answer and output 'option {i}'.",
]

# Task descriptions for the explanation role.
explanation_description_array = [
    "Explain in detail why the chosen telecom MCQ answer is correct.",
    "Provide a step-by-step rationale for why the selected answer is correct.",
    "As a telecom expert, justify why the chosen MCQ option is right.",
    "Offer a clear explanation for why the selected telecom answer is correct.",
    "Give a detailed rationale for why the chosen option is correct.",
]

# ─── STEP 3: Document corpus ────────────────────────────────────────────────────
# Flat list handed to get_context: the three description arrays first, then
# every question. The order matters because get_context is also given the
# lengths of the first two arrays (presumably to use them as offsets).
documents = [
    *summary_description_array,
    *diagnosis_description_array,
    *explanation_description_array,
    *input_reports,
]

# ─── STEP 5: Deployment instructions per arm ───────────────────────────────────────────
# Every table below maps an arm key to a (deployment name, instruction) pair.
# The instruction text is shared by all arms of a role, except that the
# diagnoser's "base" arm additionally spells out the valid option range.
_SUMMARIZER_PROMPT = "You are to summarize a telecom question and its options."
_DIAGNOSER_PROMPT = (
    "You are to answer multiple choice questions related to telecommunications. "
    "Output strictly 'option {i}'."
)
_DIAGNOSER_PROMPT_BASE = (
    "You are to answer multiple choice questions related to telecommunications. "
    "Output strictly 'option {i}' where i∈{1,2,3,4}."
)
_EXPLAINER_PROMPT = (
    "You are to explain why the MCQ answer for this telecom question is correct. "
    "Provide a detailed rationale."
)

# Arm key -> deployment name, listed in the order the arms are enumerated.
_ARM_DEPLOYMENTS = {
    "base": "gpt-35-turbo",
    "assistants": "Assistant",
    "finetune_med": "Med",
    "finetune_tele": "Tele",
    "finetune_med_new": "Med_New",
    "llama": "llama",
}

# The summarizer may use every arm, including "assistants".
deployments_summarizer = {
    arm: (deployment, _SUMMARIZER_PROMPT)
    for arm, deployment in _ARM_DEPLOYMENTS.items()
}

# Diagnoser and explainer do not offer the "assistants" arm.
deployments_diagnoser = {
    arm: (
        deployment,
        _DIAGNOSER_PROMPT_BASE if arm == "base" else _DIAGNOSER_PROMPT,
    )
    for arm, deployment in _ARM_DEPLOYMENTS.items()
    if arm != "assistants"
}

deployments_explainer = {
    arm: (deployment, _EXPLAINER_PROMPT)
    for arm, deployment in _ARM_DEPLOYMENTS.items()
    if arm != "assistants"
}

# ─── STEP 4: Cost-per-token dictionaries ────────────────────────────────────────
# Price charged per generated (output) token, keyed by arm.
cost_per_token = dict(
    base=1.5e-6,
    assistants=1.5e-6,
    finetune_med=1e-5,
    finetune_tele=1e-5,
    finetune_med_new=1e-5,
    llama=7.1e-7,
)

# Price charged per prompt (input) token, keyed by arm.
input_cost_per_token = dict(
    base=5e-7,
    assistants=5e-7,
    finetune_med=2.5e-7,
    finetune_tele=2.5e-7,
    finetune_med_new=2.5e-7,
    llama=7.1e-7,
)
# ─── STEP 5: Token-length predictor ─────────────────────────────────────────────
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoTokenizer as HFTokenizer
import json
from tok_length_predict import BertRegressionModel
import tiktoken
from sentence_transformers import SentenceTransformer

# BERT backbone + tokenizer used by the output-length regression model.
reg_model_name = "bert-base-uncased"
reg_config     = AutoConfig.from_pretrained(reg_model_name)
reg_tokenizer  = AutoTokenizer.from_pretrained(reg_model_name)

# The position of an LLM name in this list is the index of its one-hot slot
# when querying the length model.
with open("model_names.json", encoding="utf-8") as f:
    orig_model_names = json.load(f)
num_models = len(orig_model_names)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

token_length_model = BertRegressionModel(
    reg_config, reg_model_name,
    hidden_dim=128,
    num_models=num_models
).to(device)
# NOTE: torch.load unpickles the checkpoint — only load files from a trusted
# source.
token_length_model.load_state_dict(torch.load("best_length_model.pth", map_location=device))
token_length_model.eval()  # inference only

# Sentence-embedding model handed to get_context when building bandit contexts.
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Tokenizers used to count prompt tokens for cost accounting.
openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = {m: tiktoken.encoding_for_model(m) for m in openai_models}
try:
    llama_tok = HFTokenizer.from_pretrained("openlm-research/open_llama_13b")
except Exception as exc:
    # Best-effort: keep the run alive with the BERT tokenizer, but say so,
    # because llama token counts (and therefore costs) become approximate.
    print(f"Warning: could not load the llama tokenizer ({exc!r}); "
          f"falling back to the {reg_model_name} tokenizer for llama token counts.")
    llama_tok = reg_tokenizer


# Which underlying LLM each arm is treated as for tokenization and for the
# length model's one-hot input.
arm_to_llm = {
        "base"            : "gpt-3.5-turbo",
        "assistants"      : "gpt-3.5-turbo",
        "finetune_med"    : "gpt-4",
        "finetune_tele"   : "gpt-4",
        "finetune_med_new": "gpt-4",
        "llama"           : "llama-13b"
    }

# Arm -> tokenizer: the tiktoken encoding when one exists for the arm's LLM,
# otherwise the llama tokenizer.
arm_encoders = {
    mk: encodings.get(llm_name, llama_tok)
    for mk, llm_name in arm_to_llm.items()
}
# ─── STEP 6: NeuralUCB Bandit ──────────────────────────────────────────────────
class NeuralUCBDiag:
    def __init__(self, style, dim, lamdba=1, nu=1, hidden=100):
        self.device = device
        self.net    = extend(nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden,1)).to(self.device))
        self.lamdba = lamdba
        self.nu     = nu
        p_count     = sum(p.numel() for p in self.net.parameters())
        self.U      = lamdba * torch.ones(p_count, device=self.device)
        self.contexts = []
        self.rewards  = []
        self.style    = style
    def selection(self, context, style):
        x = torch.from_numpy(context).float().to(self.device).unsqueeze(0)
        mu = self.net(x)
        self.net.zero_grad(); mu.backward(retain_graph=True)
        grads = torch.cat([p.grad.flatten() for p in self.net.parameters()])
        sigma = torch.sqrt(torch.sum((self.lamdba*self.nu*grads*grads/self.U)))
        score = (0.2*mu.item()+2*sigma.item()) if style=='ucb' else torch.normal(1.0*mu.view(-1),0.05*sigma.view(-1)).item()
        self.U += grads*grads
        return score
    def train(self, context, reward):
        c = torch.from_numpy(context).float().to(self.device).unsqueeze(0)
        self.contexts.append(c); self.rewards.append(float(reward))
        optimizer = optim.SGD(self.net.parameters(), lr=1e-4, weight_decay=self.lamdba)
        tot_loss = 0; cnt=0
        for ctx, r in zip(self.contexts, self.rewards):
            optimizer.zero_grad()
            pred = self.net(ctx).view(-1)[0]
            loss = (pred-r)**2
            loss.backward(); optimizer.step()
            tot_loss += loss.item(); cnt+=1
            if cnt>=5: break
        return tot_loss/cnt if cnt>0 else 0

# ─── STEP 7: Args ─────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument('--size', default=100, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularization')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ts', metavar='ts|ucb', help='TS or UCB')
parser.add_argument('--number_tasks', default=3, type=int, help='number of subtasks')
parser.add_argument('--no_runs', default=5, type=int, help='how many independent runs')
# float rather than int so fractional trade-off weights (e.g. 0.5) are accepted;
# integer values on the command line still parse as before.
parser.add_argument('--alpha', default=10.0, type=float, help='cost accuracy tradeoff weight')
args = parser.parse_args()
size, nu, lamdba, hidden, style, number_tasks, no_runs, alpha = (
    args.size, args.nu, args.lamdba, args.hidden, args.style,
    args.number_tasks, args.no_runs, args.alpha
)
num_rounds = size

# ─── STEP 8: Prepare models & containers ───────────────────────────────────────
# Arm keys per role, in table order.
models_summarizer = list(deployments_summarizer.keys())
models_diagnoser  = list(deployments_diagnoser.keys())
models_explainer  = list(deployments_explainer.keys())
# Per-run histories, one entry appended per run.
all_regrets, all_rewards, all_costs = [], [], []
# A super-arm is one (summarizer, diagnoser, explainer) combination.
super_arms = [(s, d, e) for s in models_summarizer for d in models_diagnoser for e in models_explainer]
num_triplets = len(super_arms)
# all_plays[run, k] = number of rounds in which super-arm k was played.
all_plays = np.zeros((args.no_runs,num_triplets))
all_avg_arrays = []


# ─── RUN SIMULATIONS ───────────────────────────────────────────────────────────
def _count_tokens(encoder, text):
    """Return how many tokens `text` occupies under an arm's encoder.

    Encoders exposing ``encode`` are used through it; anything else is called
    directly and its ``"input_ids"`` are counted.
    """
    if hasattr(encoder, "encode"):
        return len(encoder.encode(text))
    return len(encoder(text, truncation=True)["input_ids"])


def _predict_output_len(toks, llm_name):
    """Query the length model for `llm_name` on an already-tokenized question.

    Raises ValueError if `llm_name` is not listed in ``orig_model_names``.
    """
    onehot = torch.zeros(num_models, device=device)
    onehot[orig_model_names.index(llm_name)] = 1.0
    with torch.no_grad():
        return token_length_model(
            toks["input_ids"], toks["attention_mask"], onehot.unsqueeze(0)
        ).item()


def _call_cost(arm, in_len, out_len):
    """Cost of one call on `arm` with the given input/output token counts."""
    return input_cost_per_token[arm]*in_len + cost_per_token[arm]*out_len


# Fail before any model call is paid for instead of with an IndexError mid-run.
if args.no_runs > 0 and args.size > len(input_reports):
    raise ValueError(
        f"--size {args.size} exceeds the number of available questions "
        f"({len(input_reports)})"
    )

n_sum_desc = len(summary_description_array)
n_diag_desc = len(diagnosis_description_array)
# LLM names whose output length must be predicted each round (summarizer and
# explainer arms; the diagnoser's output length is a fixed estimate).
cost_llms = list(dict.fromkeys(
    arm_to_llm[a] for a in models_summarizer + models_explainer
))

for run in range(args.no_runs):
    print(f"=== Run {run+1}/{args.no_runs} ===")
    # One learner per role; each super-arm is scored as the sum of the three.
    u_s = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    u_d = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    u_e = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    actual_total_cost = 0

    plays_triplet = np.zeros(num_triplets, dtype=int)
    regrets, rewards, costs = [], [], []
    tot_reward = 0; cum_reg = 0
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}

    for t in range(args.size):
        print(f"Round {t+1}")
        question = input_reports[t]
        toks_q = reg_tokenizer(question, truncation=True, padding="max_length", max_length=256, return_tensors="pt").to(device)

        # These depend only on the question and the arm / LLM, not on the
        # triplet, so compute them once per round instead of once per triplet.
        in_len_by_arm = {
            arm: _count_tokens(arm_encoders[arm], question)
            for arm in models_summarizer
        }
        out_len_by_llm = {
            llm: _predict_output_len(toks_q, llm) for llm in cost_llms
        }

        triplet_contexts_s, triplet_contexts_d, triplet_contexts_e, triplet_scores = [], [], [], []
        for (s_arm, d_arm, e_arm) in super_arms:
            j_s = models_summarizer.index(s_arm)
            j_d = models_diagnoser.index(d_arm)
            j_e = models_explainer.index(e_arm)

            # NOTE(review): get_context is re-evaluated for every triplet even
            # though its arguments only vary per (role, arm); if it is a pure
            # function it could be cached per round as well.
            cont_s = get_context(documents, t, 0, j_s, n_sum_desc, n_diag_desc, 0, inp_model, dataset)
            cont_d = get_context(documents, t, 1, j_d, n_sum_desc, n_diag_desc, 0, inp_model, dataset)
            cont_e = get_context(documents, t, 2, j_e, n_sum_desc, n_diag_desc, 0, inp_model, dataset)
            triplet_contexts_s.append(cont_s)
            triplet_contexts_d.append(cont_d)
            triplet_contexts_e.append(cont_e)

            # Predicted token counts for the three calls of this triplet.
            in_len_sum = in_len_by_arm[s_arm]
            out_len_sum_pred = out_len_by_llm[arm_to_llm[s_arm]]
            in_len_diag_est = int(round(out_len_sum_pred))   # diagnoser reads the summary
            out_len_diag_pred = 3                            # fixed estimate for the answer
            in_len_exp_est = in_len_sum + 4                  # question plus the chosen answer
            out_len_exp_pred = out_len_by_llm[arm_to_llm[e_arm]]

            pred_cost = (
                input_cost_per_token[s_arm]*in_len_sum + cost_per_token[s_arm]*out_len_sum_pred +
                input_cost_per_token[d_arm]*in_len_diag_est + cost_per_token[d_arm]*out_len_diag_pred +
                input_cost_per_token[e_arm]*in_len_exp_est + cost_per_token[e_arm]*out_len_exp_pred
            )

            v_s = u_s.selection(cont_s, args.style)
            v_d = u_d.selection(cont_d, args.style)
            v_e = u_e.selection(cont_e, args.style)

            # Cost-penalised score of the super-arm.
            triplet_scores.append((v_s + v_d + v_e) - args.alpha * pred_cost)

        # UCB takes the first maximiser; otherwise ties are broken at random.
        best_idx = int(np.argmax(triplet_scores)) if args.style=='ucb' else \
           int(np.random.choice(np.flatnonzero(np.array(triplet_scores)==np.max(triplet_scores))))
        s_arm, d_arm, e_arm = super_arms[best_idx]
        plays_triplet[best_idx] += 1
        print(f"[Round {t+1}] Selected triplet -> "
      f"Summarizer: {s_arm} | Diagnoser: {d_arm} | Explainer: {e_arm}")

        # ── Summarizer ──
        summary = get_summary(question, s_arm, "tele")
        summary_clean = summary.replace("\n","")
        in_len_sum_actual = in_len_by_arm[s_arm]
        # The summary's real token count is not measured; use the length
        # predicted for the *selected* summarizer (not whichever triplet
        # happened to be scored last).
        out_len_sum_actual = int(round(out_len_by_llm[arm_to_llm[s_arm]]))

        # ── Diagnoser ──
        prompt_d = summary_clean + " Please give the correct option in the format: option [correct option number]."
        reg1, reward1, out_len_diag_actual, avg_array, _, _ = opt_eval(
            deployments_diagnoser, prompt_d, "diagnosis",
            d_arm, avg_array, t, [], [], labels, dataset
        )
        in_len_diag_actual = _count_tokens(arm_encoders[d_arm], prompt_d)

        # ── Explainer ──
        # NOTE(review): reward1 is also accumulated as a reward below, yet it
        # is used here as the option number shown to the explainer — confirm
        # what opt_eval returns in that position.
        answer_text = f"option {reward1}" if isinstance(reward1,(int,str)) else "option 1"
        prompt_e = question + " Answer chosen: " + str(answer_text)
        reg2, reward2, out_len_exp_actual, avg_array, _, _ = opt_eval(
            deployments_explainer, prompt_e, "explanation",
            e_arm, avg_array, t, [], [], explanations, dataset
        )
        in_len_exp_actual = _count_tokens(arm_encoders[e_arm], prompt_e)

        # ── Cost of this round ──
        round_cost = (
            _call_cost(s_arm, in_len_sum_actual, out_len_sum_actual) +
            _call_cost(d_arm, in_len_diag_actual, out_len_diag_actual) +
            _call_cost(e_arm, in_len_exp_actual, out_len_exp_actual)
        )
        actual_total_cost += round_cost

        # Update metrics: regret and reward are cumulative, cost is per round.
        cum_reg += (reg1 + reg2)
        tot_reward += (reward1 + reward2)
        regrets.append(cum_reg)
        rewards.append(tot_reward)
        print(f"Reward: {tot_reward} | Regret: {cum_reg} | Actual total cost: {actual_total_cost}")
        costs.append(round_cost)

        # Train each role's learner on the chosen super-arm's context.
        # NOTE(review): the summarizer has no reward of its own and is fit to
        # a constant 0 target; a cost-aware alternative would be
        # -args.alpha * (summarizer cost of this round). Confirm which is
        # intended.
        u_s.train(triplet_contexts_s[best_idx], 0)
        u_d.train(triplet_contexts_d[best_idx], float(reward1))
        u_e.train(triplet_contexts_e[best_idx], float(reward2))

    all_regrets.append(regrets)
    all_rewards.append(rewards)
    all_costs.append(costs)
    all_plays[run,:] = plays_triplet
    all_avg_arrays.append(avg_array.copy())

import pandas as pd


def _dump_pickle(obj, path):
    """Pickle `obj` to `path`, closing the file even if dumping fails."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


# Per-deployment statistics across runs (computed but currently not persisted).
avg_df      = pd.DataFrame(all_avg_arrays)
avg_mean = avg_df.mean(axis=0).to_dict()
avg_std  = avg_df.std(axis=0).to_dict()
plays_mean = all_plays.mean(axis=0)

# Mean / std across runs, per round.
regrets_mean, regrets_std = np.mean(all_regrets, axis=0), np.std(all_regrets, axis=0)
rewards_mean, rewards_std = np.mean(all_rewards, axis=0), np.std(all_rewards, axis=0)
costs_mean, costs_std = np.mean(all_costs, axis=0), np.std(all_costs, axis=0)

# ─── STEP 9: Save metrics ───────────────────────────────────────────────────────
_dump_pickle(regrets_mean, "regrets_mean_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(regrets_std,  "regrets_std_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(rewards_mean, "rewards_mean_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(rewards_std,  "rewards_std_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(costs_mean,   "costs_mean_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(costs_std,    "costs_std_neucb_budgeted_joint_tele_1.pkl")
_dump_pickle(plays_mean,   "plays_neucb_joint_budgeted_tele_1.pkl")

# Regret and reward histories are cumulative, so their last entry is the run
# total; the cost history is per round, so this is the last round's mean cost.
print(f"Final mean regret: {regrets_mean[-1]}")
print(f"Final mean reward: {rewards_mean[-1]}")
print(f"Final mean cost: {costs_mean[-1]}")
print(f"Final mean plays: {plays_mean}")

print("All runs complete. Summary pickles written.")
